# -*- coding: utf-8 -*-
'''
假设输入的input_ids 为：[[101,5,6,7,102,3,9,102],[101,10,11,102,7,9,102,0],,[101,10,102,7,9,102,0,0]]
其中：101代表[CLS],102代表[SEP],0代表[PAD]
对应的 token_type_ids_for_mask 为：[[1,1,1,1,1,0,0,0],[1,1,1,1,0,0,0,-1],[1,1,1,0,0,0,-1,-1]]
构造如下特殊的 mask ，使得输入部分的Attention是双向的，输出部分的Attention是单向：
tensor([[[[1., 1., 1., 1., 1., 0., 0., 0.],
          [1., 1., 1., 1., 1., 0., 0., 0.],
          [1., 1., 1., 1., 1., 0., 0., 0.],
          [1., 1., 1., 1., 1., 0., 0., 0.],
          [1., 1., 1., 1., 1., 0., 0., 0.],
          [1., 1., 1., 1., 1., 1., 0., 0.],
          [1., 1., 1., 1., 1., 1., 1., 0.],
          [1., 1., 1., 1., 1., 1., 1., 1.]]],


        [[[1., 1., 1., 1., 0., 0., 0., 0.],
          [1., 1., 1., 1., 0., 0., 0., 0.],
          [1., 1., 1., 1., 0., 0., 0., 0.],
          [1., 1., 1., 1., 0., 0., 0., 0.],
          [1., 1., 1., 1., 1., 0., 0., 0.],
          [1., 1., 1., 1., 1., 1., 0., 0.],
          [1., 1., 1., 1., 1., 1., 1., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0.]]],


        [[[1., 1., 1., 0., 0., 0., 0., 0.],
          [1., 1., 1., 0., 0., 0., 0., 0.],
          [1., 1., 1., 0., 0., 0., 0., 0.],
          [1., 1., 1., 1., 0., 0., 0., 0.],
          [1., 1., 1., 1., 1., 0., 0., 0.],
          [1., 1., 1., 1., 1., 1., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0.]]]])
'''
import torch

token_type_ids_for_mask = torch.tensor([[1,1,1,1,1,0,0,0],[1,1,1,1,0,0,0,-1],[1,1,1,0,0,0,-1,-1]])
mask = torch.ones((1, 1, 8, 8), dtype=torch.float32).tril()

t1 = token_type_ids_for_mask.unsqueeze(1).unsqueeze(2).float()
t2 = (token_type_ids_for_mask != -1).unsqueeze(1).unsqueeze(3).float()

mask = ((mask+t1)*t2 > 0).float()
print(mask)